
library(FNN)
library(glmnet)
library(tensorflow)
# Cut a numeric vector into a factor of quantile-based groups.
#
# Args:
#   x:     numeric vector to bin.
#   q:     either a single number (the number of groups; q + 1 evenly spaced
#          probabilities between 0 and 1 are then used) or a vector of
#          probabilities handed to quantile().
#   na.rm: passed to quantile() and to min() when break points are relocated.
#   ...:   forwarded to cut().
#
# Returns:
#   A factor. When all quantiles are distinct this is simply
#   cut(x, quantiles, include.lowest = TRUE). When some quantiles coincide,
#   values equal to a duplicated quantile get their own single-value level and
#   the remaining values are cut on the relocated unique break points.
quantcut <- function (x, q = 4, na.rm = TRUE, ...) 
{
  if (length(q) == 1) 
    q <- seq(0, 1, length.out = q + 1)
  quant <- quantile(x, q, na.rm = na.rm)
  dups <- duplicated(quant)
  if (any(dups)) {
    # Values sitting exactly on a duplicated quantile are labelled "[value]".
    flag <- x %in% unique(quant[dups])
    retval <- ifelse(flag, paste("[", as.character(x), "]", 
                                 sep = ""), NA)
    uniqs <- unique(quant)
    # Move a break point to the smallest observed x that is >= the break
    # (the break itself is kept when no observation is that large).
    reposition <- function(cut) {
      flag <- x >= cut
      if (sum(flag) == 0) 
        return(cut)
      else return(min(x[flag], na.rm = na.rm))
    }
    newquant <- sapply(uniqs, reposition)
    retval[!flag] <- as.character(cut(x[!flag], breaks = newquant, 
                                      include.lowest = TRUE, ...))
    # Levels are ordered by the x values they first appear at.
    levs <- unique(retval[order(x)])
    retval <- factor(retval, levels = levs)
    # Recover (lower, upper) bounds from each label by splitting on every
    # character that is not a digit, "+", "." or "-". A single-value label
    # yields two pieces (bound repeated); an interval label yields three.
    # NOTE(review): a label in scientific notation would be split at the "e"
    # by this pattern - confirm such labels cannot occur.
    mkpairs <- function(x) sapply(x, function(y) if (length(y) == 
                                                     2) 
      y[c(2, 2)]
      else y[2:3])
    pairs <- mkpairs(strsplit(levs, "[^0-9+\\.\\-]+"))
    rownames(pairs) <- c("lower.bound", "upper.bound")
    colnames(pairs) <- levs
    # Bracket style: only the first level starts closed on the left; every
    # level starts closed on the right. Bounds are compared as strings.
    closed.lower <- rep(F, ncol(pairs))
    closed.upper <- rep(T, ncol(pairs))
    closed.lower[1] <- TRUE
    for (i in 2:ncol(pairs)) if (pairs[1, i] == pairs[1, 
                                                      i - 1] && pairs[1, i] == pairs[2, i - 1]) 
      closed.lower[i] <- FALSE
    # The right side opens when the next level is a single-value level located
    # at this level's upper bound.
    for (i in 1:(ncol(pairs) - 1)) if (pairs[2, i] == pairs[1, 
                                                            i + 1] && pairs[2, i] == pairs[2, i + 1]) 
      closed.upper[i] <- FALSE
    # Single-value levels are shown as the bare value, others as an interval.
    levs <- ifelse(pairs[1, ] == pairs[2, ], pairs[1, ], 
                   paste(ifelse(closed.lower, "[", "("), pairs[1, ], 
                         ",", pairs[2, ], ifelse(closed.upper, "]", ")"), 
                         sep = ""))
    levels(retval) <- levs
  }
  else retval <- cut(x, quant, include.lowest = TRUE, ...)
  return(retval)
}

# Nearest-neighbour matching estimate of a treatment effect.
#
# Every treated unit is matched (with replacement) to its `nMatch` nearest
# control units in the space spanned by `features_mat`; the estimate is the
# mean treated outcome minus the mean outcome over all matched controls.
#
# Args:
#   features_mat:        matrix-like object, one row per unit.
#   treatment_indicator: vector with 1 for treated and 0 for control units.
#   Y_obs:               observed outcomes, aligned with the rows.
#   nMatch:              number of control neighbours per treated unit.
#
# Returns:
#   A single numeric value.
nn_match_fxn <- function(features_mat, treatment_indicator, Y_obs, nMatch){ 
  features_mat <- as.matrix(features_mat)
  is_treated <- treatment_indicator == 1
  is_control <- treatment_indicator == 0
  # drop = FALSE keeps a one-row group as a 1 x p matrix; without it a single
  # treated (or control) unit would collapse into a p x 1 column.
  features_t <- features_mat[is_treated, , drop = FALSE]
  features_c <- features_mat[is_control, , drop = FALSE]
  # knnx.index gives one row of control positions per treated unit; the
  # positions refer to rows of features_c, hence the control-only outcome
  # vector is indexed below.
  match_indices <- c(knnx.index(data = features_c, 
                                query = features_t, k = nMatch))
  Ybar_C_adj <- mean(Y_obs[is_control][match_indices])
  Ybar_T_adj <- mean(Y_obs[is_treated])
  Ybar_T_adj - Ybar_C_adj
} 

# Full-matching estimate of the treatment effect on the treated.
#
# The data are first passed through causal_reduce() (defined elsewhere), which
# supplies the reduced outcome/treatment vectors, a distance object and a data
# frame. fullmatch() (from optmatch, which must be attached by the caller)
# then groups units into matched sets. Within each set the treated-minus-
# control mean difference is computed, and the sets are combined with weights
# equal to their share of treated units.
#
# Args:
#   features_mat:        matrix-like object, one row per unit.
#   treatment_indicator: vector with 1 for treated and 0 for control units.
#   nMatch:              accepted for interface symmetry with nn_match_fxn;
#                        not used by full matching.
#   Y_obs:               observed outcomes.
#   PropKeep_caliper:    forwarded to causal_reduce(). Defaults to 1, the value
#                        the other functions in this file use, so that callers
#                        omitting it no longer fail with a missing argument.
#
# Returns:
#   A single numeric value.
optimal_match_fxn <- function( features_mat, treatment_indicator, 
                               nMatch = NULL, 
                               Y_obs, PropKeep_caliper = 1){ 
  causal_reduce_list <- causal_reduce(treatment_indicator = treatment_indicator, 
                                      Y_obs = Y_obs, 
                                      features_mat = features_mat, 
                                      PropKeep_caliper = PropKeep_caliper) 
  Y_obs <- causal_reduce_list$Y_obs
  treatment_indicator <- causal_reduce_list$treatment_indicator
  
  full_match <- fullmatch(causal_reduce_list$dist_match,
                          min.controls = 1, 
                          max.controls = Inf, 
                          data = causal_reduce_list$data_temp)
  my_match <- summary( full_match )$thematch
  
  # Align outcome and treatment with the matched-set labels by unit name.
  treatment_indicator <- treatment_indicator[names(my_match)]
  Y_obs <- Y_obs[names(my_match)]
  
  n_t <- sum(treatment_indicator)
  # One weighted within-set difference per matched set; tapply hands the
  # positions of the set members to the function.
  inner_tau <- tapply(seq_along(my_match), 
                      my_match, 
                      function(members){ 
                        Y_red <- Y_obs[members] 
                        treat_red <- treatment_indicator[members] 
                        sum(treat_red) / n_t * 
                          ( mean(Y_red[treat_red == 1]) - mean(Y_red[treat_red == 0]) )
                      })
  sum( inner_tau )
} 

# Semi-synthetic benchmark of the estimators in battery_fxn().
#
# Only the original control units are used. In every replication half of them
# are re-labelled "treated" with sampling weights given by an estimated score,
# a known effect is planted in their outcomes, and battery_fxn() is asked to
# recover it. The absolute errors of all estimators are written to
# "./semisynthetic_exp_<saveName>.csv".
#
# Args:
#   full_X:                     covariates, one row per unit.
#   Y_obs:                      observed outcomes.
#   true_treatment_indicator:   original treatment labels (1/0).
#   typePropensity:             "lasso" (cv.glmnet on treatment; the fit uses
#                               alpha = 0, i.e. a ridge penalty), "forest"
#                               (randomForest) or "OnY" (rank of a cv.glmnet
#                               prediction of the outcome).
#   synthetic_treatment_effect: effect to plant; must be supplied.
#   n_replications:             number of synthetic assignments.
#   length_prob_seq:            unused; kept for interface compatibility.
#   max_n:                      optional cap on the rows given to battery_fxn.
#   plot_propensity:            unused; kept for interface compatibility.
#   saveName:                   suffix of the output file name.
#
# Returns:
#   Invisibly, the data frame of absolute errors (one row per successful
#   replication), or NULL when every replication failed.
semisynthetic_testor <- function(full_X, 
                                 Y_obs, 
                                 true_treatment_indicator, 
                                 typePropensity = "lasso", 
                                 synthetic_treatment_effect = NULL, 
                                 n_replications = 5, 
                                 length_prob_seq = 1, 
                                 max_n = Inf, 
                                 plot_propensity = F, 
                                 saveName ){
  # Without an effect the outcome update below cannot work; fail clearly.
  if(is.null(synthetic_treatment_effect)){ 
    stop("synthetic_treatment_effect must be supplied", call. = FALSE)
  }
  
  if(typePropensity == "lasso"){ 
    my_glmnet <-  cv.glmnet(x = as.matrix(full_X), 
                            y = as.matrix(true_treatment_indicator), 
                            nfolds = 10, 
                            nlambda = 100, 
                            alpha = 0, #ridge 
                            family = 'binomial')
    propensity_scores <- predict(my_glmnet, newx = as.matrix(full_X), s = "lambda.min", 
                                 type = "response")
  } else if(typePropensity == "forest"){ 
    library(randomForest)
    myForest <- randomForest(x = full_X, 
                             y = as.factor( true_treatment_indicator), 
                             ntree = 50, 
                             do.trace = F) 
    # NOTE(review): predict.randomForest names its data argument `newdata`;
    # `newx` is presumably ignored, so out-of-bag votes would be returned
    # (which would explain the NA handling below) - confirm before changing.
    propensity_scores <- predict( myForest, newx =  full_X, type = "prob")[,2]
    propensity_scores[is.na(propensity_scores)] <- mean(propensity_scores, na.rm = T)
  } else if(typePropensity == "OnY"){ 
    my_glmnet <-  cv.glmnet(x = as.matrix(full_X), 
                            y = as.matrix(Y_obs), 
                            nfolds = 10, 
                            nlambda = 100, 
                            alpha = 0, #ridge 
                            family = 'gaussian')
    Y_obs_hat <- predict(my_glmnet, newx = as.matrix(full_X), s = "lambda.min")
    propensity_scores <- rank(Y_obs_hat) / max(rank(Y_obs_hat))
  } else { 
    stop(sprintf("unknown typePropensity: %s", typePropensity), call. = FALSE)
  }
  # Keep the sampling weights away from 0 and 1.
  propensity_scores[propensity_scores > 0.99] <- 0.99
  propensity_scores[propensity_scores < 0.01] <- 0.01
  
  # Restrict everything to the original control units.
  is_control <- true_treatment_indicator == 0
  X_c_orig <- full_X[is_control, , drop = FALSE]
  Y_obs <- Y_obs[is_control]
  propensity_scores <- propensity_scores[is_control]
  
  # Drop units whose outcome is below both thresholds (3 x third quartile and
  # first quartile minus 4 standard deviations).
  drop_logical <- (Y_obs < 3*summary(Y_obs)[5] & 
                     Y_obs < summary(Y_obs)[2] - 4 * sd(Y_obs, na.rm = T) ) 
  X_c_orig <- X_c_orig[!drop_logical, , drop = FALSE]
  Y_obs <- Y_obs[!drop_logical]
  propensity_scores <- propensity_scores[!drop_logical]
  
  n_units <- length(propensity_scores)
  outcome_is_continuous <- length(unique(Y_obs)) > 10
  replication_errors <- vector("list", n_replications)
  for(iter_ii in seq_len(n_replications)) {
    X_synthetic <- X_c_orig
    Y_obs_synthetic <- Y_obs
    
    # Exactly half of the units become synthetic treated units.
    treatment_indicator_synthetic <- rep(0, times = n_units)
    treatment_indicator_synthetic[sample(seq_len(n_units), 
                                         size = round(0.5 * n_units), 
                                         prob = propensity_scores) ] <- 1
    is_synth_treated <- treatment_indicator_synthetic == 1
    if(outcome_is_continuous){ 
      Y_obs_synthetic[ is_synth_treated ] <- Y_obs[ is_synth_treated ] + synthetic_treatment_effect 
    } else { 
      # Few distinct outcome values: flip enough zeros to ones among the
      # synthetic treated so that the share of the second category rises by
      # the requested effect.
      prob_t <- table(Y_obs[ is_synth_treated ])
      prob_n <- prob_t <- prob_t/sum(prob_t)
      prob_n[2] <- prob_n[2] + synthetic_treatment_effect
      prob_n[-2] <- prob_n[-2] - synthetic_treatment_effect/length(prob_n[-2])
      new_t <- round( prob_n[2] * sum(treatment_indicator_synthetic) - prob_t[2] * sum(treatment_indicator_synthetic))
      temp_d <- Y_obs[ is_synth_treated ]
      zero_positions <- which(temp_d == 0)
      # Index through sample.int so a single candidate is not mistaken for
      # the range 1:candidate by sample().
      temp_d[ zero_positions[sample.int(length(zero_positions), new_t)] ] <- 1
      Y_obs_synthetic[ is_synth_treated ] <-  temp_d
    } 
    # Effect actually planted: group difference after minus before planting.
    synthetic_means <- tapply(Y_obs_synthetic, treatment_indicator_synthetic, mean)
    original_means <- tapply(Y_obs, treatment_indicator_synthetic, mean)
    actualized_synthetic_treatment_effect <- (synthetic_means[2] - synthetic_means[1]) - 
      (original_means[2] - original_means[1])
    
    if(nrow(X_synthetic) > max_n){ 
      use_indices <- sample(1:nrow(X_synthetic), max_n)
      X_synthetic <- X_synthetic[use_indices, , drop = FALSE]
      Y_obs_synthetic <- Y_obs_synthetic[use_indices]
      treatment_indicator_synthetic <- treatment_indicator_synthetic[use_indices]
    } 
    
    # Keep only covariates that vary within both synthetic groups.
    sd_control <- apply(X_synthetic[treatment_indicator_synthetic==0, , drop = FALSE], 2, sd)
    sd_treated <- apply(X_synthetic[treatment_indicator_synthetic==1, , drop = FALSE], 2, sd)
    X_synthetic <- X_synthetic[, sd_control > 0 & sd_treated > 0, drop = FALSE]
    row.names( X_synthetic) <- 1:nrow(  X_synthetic ) 
    X_synthetic <- as.matrix(  X_synthetic )
    
    X_synthetic <- scale(X_synthetic)
    X_synthetic <- X_synthetic[, !is.na(colSums(X_synthetic)), drop = FALSE]
    
    # A failing replication is reported and skipped instead of aborting the
    # whole experiment.
    res_ii <- tryCatch( 
      battery_fxn( X = X_synthetic, 
                   treatment_indicator = treatment_indicator_synthetic, 
                   Y_obs = Y_obs_synthetic, 
                   lowDim = min(5, ncol(X_synthetic )-1), 
                   mahal_match = F, m = 20, nMatch = 3, 
                   PropKeep_caliper = 1, quiet = F), 
      error = function(e){ 
        warning(sprintf("replication %s failed: %s", iter_ii, conditionMessage(e)), 
                call. = FALSE)
        NULL
      })
    if(is.null(res_ii)){ next }
    replication_errors[[iter_ii]] <- abs(actualized_synthetic_treatment_effect - cbind(res_ii)  ) 
  }
  
  # rbind ignores the NULL entries left by skipped replications.
  error_df <- do.call(rbind, replication_errors)
  write.csv(file = sprintf("./semisynthetic_exp_%s.csv", saveName), error_df)
  invisible( error_df )
} 




# Run a battery of treatment-effect estimators on one data set.
#
# Args:
#   X:                     numeric covariate matrix, one row per unit.
#   treatment_indicator:   vector with 1 for treated and 0 for control units.
#   Y_obs:                 observed outcomes.
#   embeddings, psm, predictive_mean_match, pca_match, euclid_match,
#   readme2_match:         switches for the individual estimators.
#   mahal_match:           accepted but unused.
#   lowDim:                dimension of the PCA / random embeddings; defaults
#                          to max(floor(log(n)), 6).
#   nMatch:                neighbours per treated unit for nearest matching.
#   PropKeep_caliper:      forwarded to optimal_match_fxn().
#   match_method:          "nearest" (nn_match_fxn) or "optimal"
#                          (optimal_match_fxn) for the euclid, psm and pca
#                          estimators; the remaining estimators always use
#                          nearest-neighbour matching.
#   m:                     number of random embeddings that are averaged
#                          (by median).
#   my_caliper:            accepted but unused.
#   quiet:                 suppress progress printing when TRUE.
#
# Returns:
#   A one-row data frame with columns naive_est, psm_est, euclid_est, pca_est,
#   random_est, embeddings_est, predictive_mean_est and readme2_est;
#   estimators that were switched off are NA.
#
# The readme2 estimator relies on `tf_source_code` and `getEnvOf`, which are
# not defined in this function and must exist in the calling session.
battery_fxn <- function(X, treatment_indicator, Y_obs, 
                        embeddings = T, psm = T, mahal_match  = T, 
                        predictive_mean_match = T, 
                        pca_match = T, euclid_match = T, readme2_match = T, 
                        lowDim = NULL, nMatch = 3, PropKeep_caliper = 1, 
                        match_method = "nearest", 
                        m = 50, my_caliper = NULL, 
                        quiet = F)
{ 
  readme2_est <- predictive_mean_est <- random_est <- naive_est <- psm_est <- euclid_est <- pca_est <- embeddings_est <- NA
  
  # Matching estimate on a given feature representation, using the requested
  # matching method; NA when the method name is not recognised.
  match_estimate <- function(features_mat){ 
    if(match_method == "nearest"){ 
      return( nn_match_fxn(features_mat = features_mat, 
                           treatment_indicator = treatment_indicator, 
                           nMatch = nMatch, Y_obs = Y_obs) )
    }
    if(match_method == "optimal"){ 
      # PropKeep_caliper is always forwarded; optimal_match_fxn needs it.
      return( optimal_match_fxn(features_mat = features_mat, 
                                treatment_indicator = treatment_indicator, 
                                Y_obs = Y_obs, nMatch = nMatch, 
                                PropKeep_caliper = PropKeep_caliper) )
    }
    NA
  }
  
  #setup data 
  # NOTE(review): tf_source_code is evaluated inside this function further
  # down and may read locals by name, so the unconditional setup variables
  # are kept even where this function itself does not use them - confirm.
  colnames(X) <- paste("V", seq_len(ncol(X)), sep = "")
  is_treated <- treatment_indicator == 1
  is_control <- treatment_indicator == 0
  X_c <- X[is_control, , drop = FALSE]
  Yc_obs <- Y_obs[is_control]
  Yt_obs <- Y_obs[is_treated]
  
  sampleSize <- length(c(Y_obs))
  if(is.null(lowDim)){ lowDim <- max(floor(log(sampleSize)), 6) }    # dimension of the embedding
  Ybar_T <- mean(Yt_obs) 
  Ybar_C <- mean(Yc_obs)
  Xbar_t <- colMeans(as.matrix(X[is_treated, , drop = FALSE]))
  
  n_c <- sum(is_control)
  n_t <- sum(is_treated)
  
  #naive estimator 
  naive_est <- Ybar_T - Ybar_C  
  
  if(quiet == F){ print("starting...") } 
  
  #euclidean matching 
  if(euclid_match == T){ 
    euclid_est <- match_estimate(X)
    if(quiet == F){  print("done euclid") } 
  } 
  
  #random matching (always computed; also keeps the random stream position
  #independent of which estimators are switched on)
  RandomIndices <- sample.int(n_c, n_t, replace = TRUE) 
  Ybar_C_Random <- mean(Yc_obs[RandomIndices])
  random_est <- Ybar_T - Ybar_C_Random
  
  #psm matching 
  if(psm == T){ 
    #propensity_type <- "binomial"
    propensity_type <- "gaussian"
    my_glmnet <- cv.glmnet(x = as.matrix(X), 
                           y = as.matrix(treatment_indicator), 
                           nfolds = 5, 
                           nlambda = 100, 
                           alpha = 0, #ridge 
                           family = propensity_type)
    features_PSM <- predict(my_glmnet, newx = X, type = "response")
    # With a binomial fit, matching is done on the logit of the score.
    if(propensity_type == "binomial"){ features_PSM <-  log(features_PSM / (1 - features_PSM) ) } 
    
    psm_est <- match_estimate(features_PSM)
    if(quiet == F){  print("done psm") }
  }

  #PCA matching 
  if(pca_match == T){ 
    pc_scores <- predict(prcomp(X, center = TRUE, scale. = TRUE))
    # Never ask for more components than exist.
    n_pc <- min(lowDim, ncol(pc_scores))
    features_PCA <- as.matrix( pc_scores[, seq_len(n_pc), drop = FALSE] )  
    pca_est <- match_estimate(features_PCA)
    if(quiet == F){  print("done pca")  } 
  } 
  
  #embeddings estimator 
  if(embeddings == T){ 
    tauhat_adj_vec <- rep(NA, times = m)
    
    # Columns with fewer than 5 distinct values are treated as categorical.
    # This preparation does not depend on the iteration, so it is done once.
    profiles_template <- as.data.frame(X)
    which_use <- which(apply(profiles_template, 2, function(x) length(unique(x)))<5)
    for(ia in which_use){
      profiles_template[,ia] <- as.factor(as.character(profiles_template[,ia]))
    }
    categorical <- vapply(profiles_template, is.factor, logical(1))
    categorical_names <- names(profiles_template)[categorical]
    ndims <- sum(!categorical)
    
    for(internal_j in seq_len(m)){ 
      profiles <- profiles_template
      # embed the categorical variables
      categoricalEmbedding <- matrix(0, nrow=sampleSize, ncol=lowDim)
      for(d in seq_len(lowDim)){
        for(i in categorical_names){
          levels(profiles[[i]]) <- rnorm(nlevels(profiles[[i]]),mean=0,sd=1)  # convert levels to random numbers
          categoricalEmbedding[,d] <- categoricalEmbedding[,d] + as.numeric(as.matrix(profiles[[i]]))
        }
      }
      
      # embed the remaining variables with a Gaussian random projection
      randomProjectionMatrix <- matrix(rnorm(ndims*lowDim,mean=0,sd=1), ndims, lowDim)
      continuousEmbedding <- as.matrix(profiles[!categorical]) %*% randomProjectionMatrix
      
      features <- categoricalEmbedding + continuousEmbedding
    
      tauhat_adj_vec[internal_j] <- nn_match_fxn(features_mat = features, treatment_indicator = treatment_indicator, 
                                                 nMatch = nMatch, Y_obs = Y_obs)
    } 
    embeddings_est <- median( tauhat_adj_vec )
    
    if(quiet == F){  print("done embeddings") } 
  } 

  #readme2 matching 
  if(readme2_match == T){ 
    # Control outcomes are turned into categories: quantile bins when there
    # are at least 10 distinct outcome values, the raw values otherwise.
    if(length(unique(Y_obs))>=10){CatFactor_vec <- as.numeric(  quantcut(Y_obs[treatment_indicator == 0],q = round(max(2,min(sum(treatment_indicator == 0)/100,15))) ) )}
    if(length(unique(Y_obs))<10){CatFactor_vec <- Y_obs[treatment_indicator == 0] }
    # Standardise all units with the control means and standard deviations.
    X_scaled <- t((t(X)-colMeans(X[treatment_indicator == 0,])) / apply(X[treatment_indicator == 0,], 2, sd))
    labeled_dat <- X_scaled[treatment_indicator == 0,]
    unlabeled_dat <- X_scaled[treatment_indicator == 1,]
    
    tau_readme2_outer_alt <- tau_readme2_outer <- rep(NA, times = 20)
    NPassin_tf <- ncol(X)
    NCat_tf <- as.integer( length(unique( CatFactor_vec ) ) )
    tf_batch_size_byCat <- as.integer( 34  ) 
    # tf_source_code is expected to define tf_est_fxn when evaluated here.
    eval(parse(text = tf_source_code ))
    environment(tf_est_fxn) <- getEnvOf( "NPassin_tf" )
    for(aj in seq_along(tau_readme2_outer)){ 
      # NOTE(review): the control rows are passed as both the labeled and the
      # unlabeled set - confirm this is intended.
      X_scaled_red <- try(tf_est_fxn(in_dvm_labeled = labeled_dat, 
                                     in_dvm_unlabeled = labeled_dat, 
                                     in_category_vec_labeled = CatFactor_vec, 
                                     sgd_iters = 300,plot_results = F, 
                                     get_features_only = T, get_features_only_input = X_scaled), T)
      # A failed run leaves NA for this repetition; the median below skips it.
      if(inherits(X_scaled_red, "try-error")){ next }
      
      tau_readme2_outer[aj] <- nn_match_fxn(features_mat = X_scaled_red, treatment_indicator = treatment_indicator, nMatch = nMatch, Y_obs = Y_obs) 
    }   
    readme2_est <- median(tau_readme2_outer, na.rm = T) 
  }

  
  #predictive mean matching 
  if(predictive_mean_match == T){ 
    # Fit the outcome model on the controls only, then match all units on the
    # linear predictor (without intercept) at lambda.1se.
    my_glmnet <- glmnet::cv.glmnet(x = as.matrix(X[is_control, , drop = FALSE]), 
                                   y = as.matrix(Y_obs[is_control]), 
                                   family = "gaussian")
    my_coefs <- coef(my_glmnet, s = "lambda.1se")[-1,]
    Y_hat_0_iij <- X %*% as.matrix(my_coefs)
  
    predictive_mean_est <- nn_match_fxn(features_mat = Y_hat_0_iij, treatment_indicator = treatment_indicator, 
                                            nMatch = nMatch, Y_obs = Y_obs) 
  } 
  
  return_contents <- data.frame(
             naive_est = c(naive_est), 
             psm_est = c(psm_est), 
             euclid_est = c(euclid_est), 
             pca_est = c(pca_est), 
             random_est = c(random_est), 
             embeddings_est = c(embeddings_est),
             predictive_mean_est = c(predictive_mean_est), 
             readme2_est = c(readme2_est)) 
  return( return_contents )
}


# Nearest matches and squared distances for every treated unit.
#
# Args:
#   X:                   numeric covariates, one row per unit.
#   treatment_indicator: vector with 1 for treated and 0 for control units.
#   Y_obs:               accepted but unused.
#   n_match:             number of matches per treated unit.
#   make_double:         accepted for interface compatibility; the
#                        treated-to-treated matches of the self-coded branch
#                        are always computed, whatever its value.
#   selfcode:            TRUE uses the explicit search below, FALSE uses
#                        FNN's knnx.index.
#   A:                   optional matrix for the quadratic-form distance
#                        d' A d (self-coded branch only); NULL means squared
#                        Euclidean distance.
#
# Returns:
#   selfcode = TRUE:  list(match_mat, match_mat_double), two data frames with
#                     one row per treated unit (in random order): the treated
#                     position, the matched units and the distances. Control
#                     matches are positions within the control rows; treated
#                     matches are row names of X (1..nrow(X)).
#   selfcode = FALSE: list(dist_mat_to_c, match_mat_to_c, match_mat_to_t,
#                     dist_mat_to_t), matrices with one row per treated unit
#                     and n_match columns; positions refer to the control and
#                     treated sub-matrices, distances are squared Euclidean.
dist_fxn <- function(X, treatment_indicator, Y_obs, n_match = 1, make_double = F, selfcode = F, A = NULL){
  # Defined in both cases, so supplying A no longer fails on a missing object.
  euclidean <- is.null(A)
  X <- as.matrix(X)
  rownames(X) <- seq_len(nrow(X))
  # drop = FALSE keeps single-row / single-column inputs as matrices.
  X_t <- X[treatment_indicator == 1, , drop = FALSE]
  X_c <- X[treatment_indicator == 0, , drop = FALSE]
  n_treat <- nrow(X_t)
  
  # Distance from `point` to every row of `comparison_mat`; names follow the
  # row names of `comparison_mat`.
  row_distances <- function(point, comparison_mat, use_euclidean = euclidean){ 
    diff_mat <- sweep(comparison_mat, 2, point)
    if(use_euclidean){ 
      rowSums(diff_mat^2)
    } else { 
      # row-wise d' A d without forming the full n x n product
      rowSums( (diff_mat %*% A) * diff_mat )
    }
  }
  
  if(selfcode == T){ 
    # Treated units are processed in random order.
    go_string <- sample.int(n_treat, size = n_treat, replace = FALSE)
    match_mat <- matrix(NA, nrow = n_treat, ncol = n_match * 2+1)
    colnames(match_mat) <- c("TreatedIndex", 
                             sprintf("ControlMatch_%s", 1:n_match), 
                             sprintf("Dist_%s", 1:n_match))
    for(i in seq_len(n_treat)){ 
      if(i %% 100 == 0){ print (sprintf("inner %s", i) )}
      dist_vec <- row_distances(X_t[go_string[i],], X_c)
      match_res <- order(dist_vec, decreasing = F)[seq_len(n_match)]
      match_mat[i,] <- c(go_string[i], match_res, dist_vec[match_res])
    }
    
    # Same search among the other treated units.
    match_mat_double <- matrix(NA_real_, nrow = n_treat, ncol = n_match * 2+1)
    colnames(match_mat_double) <- c("TreatedIndex", 
                                    sprintf("TreatedMatch_%s", 1:n_match), 
                                    sprintf("Dist_%s", 1:n_match))
    for(i in seq_len(n_treat)){ 
      if(i %% 100 == 0){ print (sprintf("inner %s", i) )}
      comparison_mat_i <- X_t[-go_string[i], , drop = FALSE]
      dist_vec <- row_distances(X_t[go_string[i],], comparison_mat_i)
      match_res <- names(dist_vec)[order(dist_vec, decreasing = F)[seq_len(n_match)]]
      # Stored as numbers directly, so distances keep full precision.
      match_mat_double[i,] <- c(go_string[i], as.numeric(match_res), 
                                unname(dist_vec[match_res]))
    } 
    return_list <- list(match_mat = as.data.frame(match_mat) , 
                        match_mat_double = as.data.frame(match_mat_double) ) 
  } 
  
  if(selfcode == F){ 
    # One row per treated unit, one column per neighbour. Keeping the matrix
    # shape (instead of flattening it) is what makes n_match > 1 work.
    match_mat_to_c <- matrix( knnx.index(as.matrix(X_c),
                                         as.matrix(X_t), k=n_match), 
                              nrow = n_treat, ncol = n_match)
    dist_mat_to_c <- matrix(NA, nrow = n_treat, ncol = n_match)
    match_mat_to_t <- dist_mat_to_t <- matrix(NA, nrow = n_treat, ncol = n_match)
    # A far-away stand-in that replaces a unit's own row, so that the unit is
    # not returned as its own nearest treated neighbour.
    handler_max <- apply(X_t, 2, function(x) 100 * max(abs(x)) ) 
    for(innermost_i in seq_len(n_treat)){ 
      point_i <- X_t[innermost_i,]
      
      #find and handle distances to t
      temp_X_t <- X_t 
      temp_X_t[innermost_i,] <- handler_max
      match_mat_to_t[innermost_i,] <- c(knnx.index(as.matrix(temp_X_t),
                                                   X_t[innermost_i, , drop = FALSE], 
                                                   k=n_match))
      dist_mat_to_t[innermost_i,] <- row_distances(
        point_i, X_t[match_mat_to_t[innermost_i,], , drop = FALSE], 
        use_euclidean = TRUE)
      
      #handle distances to c 
      dist_mat_to_c[innermost_i,] <- row_distances(
        point_i, X_c[match_mat_to_c[innermost_i,], , drop = FALSE], 
        use_euclidean = TRUE)
    }
    # Built once after the loop, so it also exists when there are no treated
    # rows.
    return_list <- list(dist_mat_to_c = dist_mat_to_c, 
                        match_mat_to_c = match_mat_to_c, 
                        match_mat_to_t = match_mat_to_t, 
                        dist_mat_to_t = dist_mat_to_t)
  } 
  return( return_list ) 
}


# Simulate covariates, treatment assignment and potential outcomes.
#
# Args:
#   n:                    number of units. With clusterX = TRUE the covariates
#                         are built from 5 clusters of round(n / 5) rows, and
#                         the resulting row count is used everywhere after.
#   p:                    number of covariates.
#   Sigma:                covariance of the covariates (clusterX = FALSE).
#   MisspecificationProp: share of the active outcome covariates that enter
#                         non-linearly (nonLinear_outcome = TRUE only).
#   n_covars_important:   the first this-many covariates affect the outcome.
#   n_covars_important_propensity: the first this-many covariates affect
#                         treatment assignment.
#   response_sd:          sd of the outcome noise (also of the noise added to
#                         the linear propensity index).
#   sd_coef, sd_coef_propensity: sd of independently drawn coefficients
#                         (coef_cor_factor = NULL only).
#   clusterX:             draw covariates from 5 Gaussian clusters with
#                         inverse-Wishart covariances (needs MCMCpack).
#   coef_cor_factor:      if given, the 2p coefficients (propensity block,
#                         then outcome block) are drawn jointly from a
#                         covariance whose generating matrix has singular
#                         values running from sqrt(coef_cor_factor) to 1.
#   nonLinear_outcome, nonLinear_propensity: use the non-linear generators.
#   fixed_treated_prop:   if given, exactly round(n * prop) units are treated,
#                         sampled with the propensities as weights.
#   treatment_effect:     constant added to Y_0 to obtain Y_1.
#
# Returns:
#   list(Y_obs, treatment_indicator, X, Y_0, Y_1, treatment_propensity).
data_gen_fxn <- function(n = 1000, 
                         p, 
                         Sigma = diag(p),
                         MisspecificationProp, 
                         n_covars_important, 
                         n_covars_important_propensity, 
                         response_sd = 1, 
                         sd_coef = 1,
                         clusterX = F, 
                         coef_cor_factor = NULL, 
                         sd_coef_propensity = 1,
                         nonLinear_outcome = F, 
                         nonLinear_propensity = F, 
                         fixed_treated_prop = NULL, 
                         treatment_effect){ 
  # rmvnorm / riwish are called through their namespaces, so the function does
  # not depend on which packages happen to be attached.
  if(clusterX == F){ 
    X <- mvtnorm::rmvnorm(n = n, mean = rep(0, p), sigma = Sigma)
  } 
  if(clusterX == T){
    my_k <- 5
    n_per_k <- round(n / my_k)
    cluster_list <- vector("list", my_k)
    for(i_new in seq_len(my_k)){ 
      mySigma <- MCMCpack::riwish(v = p+1, S = diag(p))
      myMean <- rnorm(p, 0, sd =  1 * 1 / sqrt(n_per_k))
      cluster_list[[i_new]] <- mvtnorm::rmvnorm(n = n_per_k, mean = myMean, sigma = mySigma)
    } 
    X <- do.call(rbind, cluster_list)
  } 
  # my_k * round(n / my_k) need not equal n; everything below is sized by the
  # rows that were actually generated.
  n <- nrow(X)
  rownames(X) <- seq_len(n)
  colnames( X ) <- sprintf("V%s", 1:p)
  
  # TRUE for the covariates that do not enter the respective model
  inactive_propensity <- seq_len(p) > n_covars_important_propensity
  inactive_outcome <- seq_len(p) > n_covars_important
  
  if(!is.null(coef_cor_factor)){ 
    # Random covariance for the stacked coefficient vector, built from a
    # Gaussian matrix whose singular values are replaced by an evenly spaced
    # sequence from CondNumb down to 1.
    nr <- 2*p
    nc <- 2*p
    CondNumb <- sqrt(coef_cor_factor)
    A <- matrix(rnorm(nr*nc), nrow = nr, ncol = nc)
    svd_list <- svd(A)
    svd_list$d <- seq(CondNumb, 1, length.out = min(nr, nc))
    A <- as.matrix(svd_list$u) %*% diag(svd_list$d, nrow = length(svd_list$d)) %*% t(svd_list$v)
    coef_cor_mat <- t(A) %*% A
    
    if( min(eigen(coef_cor_mat)$values) < 0 ){stop("Infeasible Sigma")}
    full_coefs <- mvtnorm::rmvnorm(1, mean = rep(0, times = 2*p), sigma = coef_cor_mat )
    
    # First block: propensity coefficients. Second block: outcome
    # coefficients (taking the first block for both would make the two
    # vectors identical and the joint covariance irrelevant).
    my_coef_propensity_model <- full_coefs[1:p]
    my_coef <- full_coefs[(p + 1):(2 * p)]
  } 
  
  if(is.null(coef_cor_factor)){ 
    my_coef_propensity_model <- rnorm(p, mean = 0, sd = sd_coef_propensity)
    my_coef <- rnorm(n = p, mean = 0, sd = sd_coef)
  } 
  my_coef_propensity_model[ inactive_propensity ] <- 0
  my_coef[ inactive_outcome ] <- 0
  
  if(nonLinear_propensity == T){ 
      # One column per active covariate: zero up to a randomly chosen observed
      # cutpoint, a signed power of |x| above it.
      n_active <- sum(abs(my_coef_propensity_model)>0)
      my_BinaryLinearMat_propensity <- vapply(seq_len(n_active), function(index){ 
        x <- X[,index]
        cutpoint <- x[sample.int(length(x), 1)]
        res <- rep(0, times = length(x))
        above <- x > cutpoint
        res[above] <- sign(my_coef_propensity_model[index]) * 
          abs(x[above])^abs(my_coef_propensity_model[index])
        res
      }, numeric(n))
      treatment_propensity <- rowSums( my_BinaryLinearMat_propensity )
  }
  if(nonLinear_propensity == F){ 
    treatment_propensity <- c( X %*% my_coef_propensity_model +rnorm(nrow(X), 0, response_sd) )  
  } 
  # logistic link
  treatment_propensity <- 1 / (1 + exp(-treatment_propensity))
  if(is.null(fixed_treated_prop)){
    treatment_indicator <- rbinom(n, size = 1, prob = treatment_propensity)
  } 
  if(!is.null(fixed_treated_prop)){ 
    treatment_indicator <- rep(0, times = n)
    treated_indices <- sample.int(n, round(n * fixed_treated_prop), prob = treatment_propensity)
    treatment_indicator[treated_indices] <- 1
  }

  if(nonLinear_outcome == T){ 
    active_indices <- which(abs(my_coef) > 0)
    # Indexing through sample.int avoids sample()'s special case for a single
    # number, which would otherwise draw from 1:active_indices.
    toNonlinear_indices <- active_indices[
      sample.int(length(active_indices), 
                 floor(MisspecificationProp*length(active_indices)))]
    my_BinaryLinearMat <- vapply(seq_len(length(my_coef)), function(index){ 
      x <- X[,index]
      res <- x * my_coef[index]
      # This draw is not used; it is kept so that the random stream of the
      # purely linear columns stays where it was.
      det_val <- runif(1)
      if(index %in% toNonlinear_indices){
        # Three mutually exclusive non-linear shapes.
        det_value_inner <- runif(1)
        if(det_value_inner < 0.33){ 
          # hinge: constant below a random observed cutpoint, linear above
          cutpoint <- x[sample.int(length(x), 1)]
          res <- rep(cutpoint, times = length(x))
          res[x > cutpoint] <- x[x > cutpoint]
          res <- res * my_coef[index]
        } else if(det_value_inner < 0.80){ 
          res <- log(abs(x)) * my_coef[index]
        } else {
          # random walk along the sorted covariate: each step has mean
          # gap * coefficient and sd sqrt(gap); mapped back to data order
          sorted_x <- sort(x)
          diff_x <- c(sorted_x) - c(NA,  sorted_x[-length(sorted_x)])
          res <- sorted_x
          for(ijaa in seq_along(diff_x)[-1]){ 
            res[ijaa] <- res[ijaa-1] + rnorm(1, 
                                             mean = diff_x[ijaa]*my_coef[index],
                                             sd = sqrt(diff_x[ijaa])*1 )
          }
          res <- res[rank(x)]
        } 
      }    
      # inactive covariates contribute exactly zero
      if(my_coef[index] == 0){res[!is.na(res)] <- 0}
      res
    }, numeric(n))
    Y_0 <- rowSums( my_BinaryLinearMat ) + rnorm(n, mean = 0, sd = response_sd)
  }
  
  if(nonLinear_outcome == F){ 
    Y_0 <- c(X %*% my_coef) + rnorm(n, mean = 0, sd = response_sd)
  } 
  Y_1 <- Y_0  + treatment_effect 
  Y_obs <- Y_0
  Y_obs[treatment_indicator==1] <- Y_1[treatment_indicator == 1]
  
  return_list <- list(Y_obs = Y_obs, 
                      treatment_indicator = treatment_indicator, 
                      X = X, 
                      Y_0 = Y_0, 
                      Y_1 = Y_1, 
                      treatment_propensity = treatment_propensity) 
  return( return_list )  
}